{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from fastai.vision import *\n",
    "import torch\n",
    "import numpy as np\n",
    "torch.cuda.set_device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.IMAGENETTE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfms = get_transforms(do_flip=False)\n",
    "data = ImageDataBunch.from_folder(path, train = 'train', valid = 'val', bs = 64, size = 224, ds_tfms = tfms).normalize(imagenet_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential\n",
       "======================================================================\n",
       "Layer (type)         Output Shape         Param #    Trainable \n",
       "======================================================================\n",
       "Conv2d               [64, 112, 112]       9,408      False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [64, 112, 112]       128        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [64, 112, 112]       0          False     \n",
       "______________________________________________________________________\n",
       "MaxPool2d            [64, 56, 56]         0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [64, 56, 56]         36,864     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [64, 56, 56]         128        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [64, 56, 56]         0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [64, 56, 56]         36,864     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [64, 56, 56]         128        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [64, 56, 56]         36,864     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [64, 56, 56]         128        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [64, 56, 56]         0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [64, 56, 56]         36,864     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [64, 56, 56]         128        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [64, 56, 56]         36,864     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [64, 56, 56]         128        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [64, 56, 56]         0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [64, 56, 56]         36,864     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [64, 56, 56]         128        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        73,728     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [128, 28, 28]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        147,456    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        8,192      False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        147,456    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [128, 28, 28]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        147,456    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        147,456    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [128, 28, 28]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        147,456    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        147,456    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [128, 28, 28]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [128, 28, 28]        147,456    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [128, 28, 28]        256        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        294,912    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [256, 14, 14]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        32,768     False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [256, 14, 14]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [256, 14, 14]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [256, 14, 14]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [256, 14, 14]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [256, 14, 14]        0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [256, 14, 14]        589,824    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [256, 14, 14]        512        True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [512, 7, 7]          1,179,648  False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [512, 7, 7]          1,024      True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [512, 7, 7]          0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [512, 7, 7]          2,359,296  False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [512, 7, 7]          1,024      True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [512, 7, 7]          131,072    False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [512, 7, 7]          1,024      True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [512, 7, 7]          2,359,296  False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [512, 7, 7]          1,024      True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [512, 7, 7]          0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [512, 7, 7]          2,359,296  False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [512, 7, 7]          1,024      True      \n",
       "______________________________________________________________________\n",
       "Conv2d               [512, 7, 7]          2,359,296  False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [512, 7, 7]          1,024      True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [512, 7, 7]          0          False     \n",
       "______________________________________________________________________\n",
       "Conv2d               [512, 7, 7]          2,359,296  False     \n",
       "______________________________________________________________________\n",
       "BatchNorm2d          [512, 7, 7]          1,024      True      \n",
       "______________________________________________________________________\n",
       "AdaptiveAvgPool2d    [512, 1, 1]          0          False     \n",
       "______________________________________________________________________\n",
       "AdaptiveMaxPool2d    [512, 1, 1]          0          False     \n",
       "______________________________________________________________________\n",
       "Flatten              [1024]               0          False     \n",
       "______________________________________________________________________\n",
       "BatchNorm1d          [1024]               2,048      True      \n",
       "______________________________________________________________________\n",
       "Dropout              [1024]               0          False     \n",
       "______________________________________________________________________\n",
       "Linear               [512]                524,800    True      \n",
       "______________________________________________________________________\n",
       "ReLU                 [512]                0          False     \n",
       "______________________________________________________________________\n",
       "BatchNorm1d          [512]                1,024      True      \n",
       "______________________________________________________________________\n",
       "Dropout              [512]                0          False     \n",
       "______________________________________________________________________\n",
       "Linear               [10]                 5,130      True      \n",
       "______________________________________________________________________\n",
       "\n",
       "Total params: 21,817,674\n",
       "Total trainable params: 550,026\n",
       "Total non-trainable params: 21,267,648\n",
       "Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99)\n",
       "Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ \n",
       "Loss function : FlattenedLoss\n",
       "======================================================================\n",
       "Callbacks functions applied "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn = cnn_learner(data, models.resnet34, metrics = accuracy)\n",
    "learn = learn.load('imagenet_bs64')\n",
    "learn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaveFeatures :\n",
    "    def __init__(self, m) : \n",
    "        self.handle = m.register_forward_hook(self.hook_fn)\n",
    "    def hook_fn(self, m, inp, outp) : \n",
    "        self.features = outp\n",
    "    def remove(self) :\n",
    "        self.handle.remove()\n",
    "        \n",
    "# saving outputs of all Basic Blocks\n",
    "net = learn.model\n",
    "sf = [SaveFeatures(m) for m in [net[0][2], net[0][4], net[0][5], net[0][6], net[0][7]]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (2): ReLU(inplace)\n",
       "  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (4): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (5): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (3): BasicBlock(\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (6): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (3): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (4): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (5): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (7): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (2): BasicBlock(\n",
       "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = learn.model\n",
    "net[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Checking for one image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample data\n",
    "x, y = next(iter(data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 224, 224])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img1 = x[None, 1]\n",
    "img1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.16 ms ± 1.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit net(torch.autograd.Variable(img1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[torch.Size([1, 64, 56, 56]),\n",
       " torch.Size([1, 64, 56, 56]),\n",
       " torch.Size([1, 128, 28, 28]),\n",
       " torch.Size([1, 256, 14, 14]),\n",
       " torch.Size([1, 512, 7, 7])]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[o.features.shape for o in sf]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Checking for entire batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The slowest run took 13.41 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
      "13.8 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit net(torch.autograd.Variable(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[torch.Size([64, 64, 56, 56]),\n",
       " torch.Size([64, 128, 28, 28]),\n",
       " torch.Size([64, 256, 14, 14]),\n",
       " torch.Size([64, 512, 7, 7])]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[o.features.shape for o in sf]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### checking if np.save works properly (not using, taking the features directly at runtime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 0\n",
    "for m in sf : \n",
    "    # converting to numpy array\n",
    "    fm = (m.features).cpu().detach().numpy()\n",
    "    filename = 'numpyarray' + str(n) + '.npy'\n",
    "    np.save(filename, fm)\n",
    "    n += 1\n",
    "    \n",
    "for i in range(4) : \n",
    "    filename = 'numpyarray' + str(i) + '.npy'\n",
    "    fm = np.load(filename)\n",
    "    assert(np.array_equal(fm, sf[i].features.cpu().detach().numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ak_fastai)",
   "language": "python",
   "name": "ak_fastai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
